/* Array Library
   See array.Doc for a detailed description of the functions.*/

#ifndef _BOBCONSTANTS
#include bob:Constants
#endif

#ifndef _string_arrayfns
#define _string_arrayfns 1
#endif

form_matrix(m,n,v)
/* ----------------------------------------------------
   Returns an m x n matrix from a VECTOR v of size m*n,
   whose rows are successive sets of n elements of v:
   row i holds v[i*n] ... v[i*n + n - 1].
   Quits if sizeof(v) != m*n.
   ----------------------------------------------------*/
{
  local i,j,V;
  /* validate the data before allocating the result */
  if (sizeof(v) != m*n) quit("form_matrix: data incompatible with order\n");
  V = newmatrix(m,n);
  for (i = 0; i < m; i++)
    for (j = 0; j < n; j++)
      V[i][j] = v[i*n + j];
  return V;
}

form_column(v)
/* ---------------------------------------------------------------
   Builds a column vector, i.e. an m x 1 matrix, from the VECTOR v,
   where m = sizeof(v).  Every element is multiplied by 1.0
   (presumably to coerce INTEGER entries to REAL).
   ---------------------------------------------------------------*/
{
  local k,len,C;
  len = sizeof(v);
  C = newmatrix(len,1);
  k = 0;
  while (k < len)
    { C[k][0] = v[k]*1.0;
      k++;
    }
  return C;
}

form_row(v)
/* --------------------------------------------------------------
   Returns a row vector, i.e. a 1 x n matrix, from a VECTOR v,
   where n = sizeof(v).  Row 0 holds v[0] ... v[n-1] unchanged.
   --------------------------------------------------------------*/
{
  local j,n,V;
  n = sizeof(v);
  V = newmatrix(1,n);
  for (j = 0; j < n; j++) V[0][j] = v[j];
  return V;
}

printarray(A,l,p)
/* ------------------------------------------------------------------
   Prints array A, formatting each element with format(x,l,p): field
   width l and p digits after the decimal point for REALs.
   A matrix (its first row has an INTEGER sizeof) is printed one row
   per line.  A vector is printed entirely on one line, followed by a
   single newline.
   ------------------------------------------------------------------*/
{
  local r,c,rows,cols;
  rows = sizeof(A);
  cols = sizeof(A[0]);
  if (typeof(cols) != INTEGER)
    { /* vector: all elements on one line */
      for (r = 0; r < rows; r++) print(format(A[r],l,p));
      print("\n");
    }
  else
    { /* matrix: one output line per row */
      for (r = 0; r < rows; r++)
        { for (c = 0; c < cols; c++) print(format(A[r][c],l,p));
          print("\n");
        }
    }
}

format(n,l,p)
/*------------------------------------------------------------------------
 format(n,l,p) formats variables of types INTEGER, REAL and STRING for
 output.  It converts REAL or INTEGER numbers n to signed decimal strings,
 with p digits after the decimal point for reals (p is irrelevant for
 integers and strings), and returns them, or strings n, either padded with
 leading spaces to make their lengths up to l or, if l is too small (in
 particular if l = 0), at full length without any leading spaces.
 REALs are rounded to p decimal places by adding half a unit in the last
 place to the magnitude and truncating, so ties round away from zero.
 A REAL whose magnitude is below that half unit prints as zero with no
 minus sign.
 Any other argument type yields an explanatory message string; format
 does not quit in that case.
 -------------------------------------------------------------------------*/
{
  local d,dec,i,int,neg,r,s;
  s = "";
  switch (typeof(n))
    {
      case REAL:
        /* work with |n| and remember the sign */
        if (neg = (n < 0.0)) n = -n;
        /* r = 0.5 * 10^-p: half a unit in the last printed place */
        r = 0.5;
        for (i = 0; i < p; i++) r /= 10.0;
        /* tiny magnitudes print as zero without "-" (no "-0.00");
           otherwise add r so the truncations below round */
        if (n < r) { n = 0.0; neg = FALSE; } else n += r;        
        int = floor(n);
        if (neg) s = "-";
        s += int_10(int);
        s += ".";
        /* fractional part scaled up to p digits and truncated */
        dec = n - 1.0*int; 
        for (i = 0; i < p; i++) dec *= 10.0;
        dec = floor(dec);
        /* emit exactly p digits, least significant first, so leading
           zeros of the fraction are kept; '0' + digit builds the digit
           character the same way int_10 does */
        d = "";
        for (i = 0;i < p; i++)
          { d  = '0' + dec%10 + d; dec /= 10; }
        s += d;
        break;
      case INTEGER:
        if (neg = (n < 0)) n = -n;
        if (neg) s = "-";
        s += int_10(n);
        break;
      case STRING:
        s = n;
        break;
      default:
        /* unsupported type: the message is returned, not raised */
        s = "format: 1st argument not real, integer or string";        
    }
  /* left-pad with spaces to width l; no-op when s is already >= l long */
  for (i = l - sizeof(s); i > 0; i --) s = " " + s;
  return s;
}

add(c,A)
/* ----------------------------------------------------------------
   Adds the constant c to every element of the array A and returns
   the result as a new array.  A may be a vector or a matrix; a
   matrix is recognised by its first row having an INTEGER sizeof.
   If c times the first element of A is INTEGER, the result is
   INTEGER.  Otherwise every term is made REAL via f = 1.0.
   Quits if c is not a number or A is not an array.
   ----------------------------------------------------------------*/
{
  local f,i,j,m,n,x,t,V;
  t = typeof(c);
  if ((t != INTEGER) && (t != REAL))
    quit("add: 1st argument is not a constant\n");
  if (typeof(A) != VECTOR) quit("add: 2nd argument is not an array\n");
  m = sizeof(A);
  n = sizeof(A[0]);
  /* sample element used to choose INTEGER or REAL arithmetic:
     A[0][0] for a matrix, A[0] for a vector (A[0][0] would index
     a number when A is a vector) */
  if (typeof(n) == INTEGER) x = A[0][0]; else x = A[0];
  if (typeof(c*x) == INTEGER) f = 1; else f = 1.0;
  c = f*c;
  if (typeof(n) == INTEGER)
    { V = newmatrix(m,n);
      for (i = 0; i < m; i++)
        for (j = 0; j < n; j++) V[i][j] = c + f*A[i][j];
    }
  else
    { V = newvector(m);
      for (i = 0; i < m; i++) V[i] = c + f*A[i];
    }
  return V;
}

sum(A,B)
/* ----------------------------------------------------------------
   Returns the array sum A + B of the arrays A and B, element by
   element.  Both must be vectors of equal length or matrices of
   equal order.  If the product of the first elements is INTEGER
   the result is INTEGER.  Otherwise every term is made REAL via
   f = 1.0.
   Quits on non-arrays or incompatible shapes.
   ----------------------------------------------------------------*/
{
  local f,i,j,ma,na,mb,nb,t,ta,tb,V;
  ta = typeof(A); tb = typeof(B);
  if (ta != VECTOR) quit("sum: 1st argument is not an array\n");
  if (tb != VECTOR) quit("sum: 2nd argument is not an array\n");
  ma = sizeof(A);
  mb = sizeof(B);
  na = sizeof(A[0]);
  nb = sizeof(B[0]);
  /* a matrix row has an INTEGER sizeof; the code assumes sizeof of a
     number is not INTEGER, which marks a vector */
  ta = typeof(na); tb = typeof(nb);
  if (ta != tb) quit("sum: arrays are not of same type\n");
  if (ta != INTEGER)
    { if (ma != mb) quit("sum: vectors are incompatible for addition\n");
      t = typeof(A[0]*B[0]);
      if (t == INTEGER) f = 1; else f = 1.0;
      V = newvector(ma);
      for (j = 0; j < ma; j++) V[j] = f*A[j] + f*B[j];
    }
  else
    { if ((ma != mb) || (na != nb)) quit("sum: matrices are incompatible for addition\n");
      t = typeof(A[0][0]*B[0][0]);
      if (t == INTEGER) f = 1; else f = 1.0;
      V = newmatrix(ma,na);
      for (i = 0; i < ma; i++)
        for (j = 0; j < na; j++) V[i][j] = f*A[i][j] + f*B[i][j];
    }
  return V;
}

multiple(c,A)
/* -------------------------------------------------------------
   Returns the array A multiplied by the constant c (INTEGER or
   REAL) as a new array.  A may be a vector or a matrix; a matrix
   is recognised by its first row having an INTEGER sizeof.
   Quits if c is not a number or A is not an array; the messages
   carry this function's own name.
   -------------------------------------------------------------*/
{
  local i,j,m,n,t,V;
  t = typeof(c);
  if ((t != INTEGER) && (t != REAL))
    quit("multiple: 1st argument is not a constant\n");
  if (typeof(A) != VECTOR) quit("multiple: 2nd argument is not an array\n");
  m = sizeof(A);
  n = sizeof(A[0]);
  if (typeof(n) == INTEGER)
    { V = newmatrix(m,n);
      for (i = 0; i < m; i++)
        for (j = 0; j < n; j++) V[i][j] = c*A[i][j];
    }
  else
    { V = newvector(m);
      for (i = 0; i < m; i++) V[i] = c*A[i];
    }
  return V;
}

product(A,B)
/* ------------------------------------------------------------------
   Returns the matrix product A.B of the array A (matrix or row vector)
   and the matrix B (mb x nb).
   - A an ma x na matrix: requires na == mb; returns an ma x nb matrix.
   - A a row vector of length ma: requires ma == mb; returns a vector
     of length nb with V[j] = sum over k of A[k]*B[k][j].
   Sums start at 0.0 when the product of the first elements is REAL,
   otherwise at 0.  Quits on non-arrays or incompatible orders.
   ------------------------------------------------------------------*/
{
  local i,j,k,ma,na,mb,nb,u,z,V;
  if (typeof(A) != VECTOR) quit("product: 1st argument is not an array\n");
  if (typeof(B) != VECTOR) quit("product: 2nd argument is not an array\n");
  ma = sizeof(A);
  na = sizeof(A[0]);
  mb = sizeof(B);
  nb = sizeof(B[0]);
  if (typeof(nb) != INTEGER) quit("product: 2nd argument is not a matrix\n");
  if (typeof(na) != INTEGER) /* A is a (row) vector */
    { if (ma != mb) quit("product: arrays are incompatible for multiplication\n");
      /* the result has one entry per column of B */
      V = newvector(nb);
      if (typeof(A[0]*B[0][0]) == REAL) z = 0.0; else z = 0;
      for (j = 0; j < nb; j++)
        { u = z;
          for (k = 0; k < ma; k++) u += A[k]*B[k][j];
          V[j] = u;
        }
    }
  else
    { if (na != mb) quit("product: arrays are incompatible for multiplication\n");
      V = newmatrix(ma,nb);
      if (typeof(A[0][0]*B[0][0]) == REAL) z = 0.0; else z = 0;
      for (i = 0; i < ma; i++)
        for (j = 0; j < nb; j++)
          { u = z;
            for (k = 0; k < na; k++) u += A[i][k]*B[k][j];
            V[i][j] = u;
          }
    }
  return V;
}

trace(A)
/* -----------------------------------------------------------------
   Returns the trace (sum of diagonal elements) of the square matrix
   A.  The sum starts at 0.0 when A[0][0] is REAL, otherwise at 0.
   Quits if A is not a matrix or is not square.
   -----------------------------------------------------------------*/
{
  local i,m,n,t;
  if (typeof(A) != VECTOR) quit("trace: argument is not a matrix\n");
  m = sizeof(A);
  n = sizeof(A[0]);
  /* a vector of numbers has no INTEGER row size */
  if (typeof(n) != INTEGER) quit("trace: argument is not a matrix\n");
  if (m != n) quit("trace: matrix is not square\n");
  /* element type comes from A[0][0]; A[0] is a row, never REAL */
  if (typeof(A[0][0]) == REAL) t = 0.0; else t = 0;
  for (i = 0; i < m; i++) t += A[i][i];
  return t;
}

transpose(A)
/* ----------------------------------------------------------
   Returns a new n x m matrix T with T[c][r] = A[r][c], where
   A is an m x n matrix.  Quits if A is not a matrix.
   ----------------------------------------------------------*/
{
  local r,c,rows,cols,T;
  if (typeof(A) != VECTOR) quit("transpose: argument is not a matrix\n");
  rows = sizeof(A);
  cols = sizeof(A[0]);
  if (typeof(cols) != INTEGER) quit("transpose: argument is not a matrix\n");
  T = newmatrix(cols,rows);
  for (c = 0; c < cols; c++)
    for (r = 0; r < rows; r++)
      T[c][r] = A[r][c];
  return T;
}

inverse(A)
/* ------------------------------------------------------------------
   Returns the inverse of a non-singular square matrix A as a new
   n x n matrix of REALs; A itself is not modified.
   Method: Gaussian elimination with partial pivoting on the augmented
   REAL matrix U = (A,I).  This reduces the left half to unit upper
   triangular form; back substitution then leaves inv(A) in the right
   half.
   The matrix is reported singular (quit) when the largest available
   pivot in some column is below tol = (sum of |a_ij|) / 1e6.  Testing
   each pivot rather than the running determinant means a uniformly
   scaled, well-conditioned matrix (e.g. 0.1*I of order 10, whose
   determinant is 1e-10) is not rejected.
   Quits if A is not a square matrix.
   ------------------------------------------------------------------*/
{
  local e,i,j,k,l,m,n,ncol,norm,p,q,s,tmp,tol,U,V;
  if (typeof(A) != VECTOR) quit("inverse: argument is not a matrix\n");
  m = sizeof(A);
  n = sizeof(A[0]);
  if (typeof(n) != INTEGER) quit("inverse: argument is not a matrix\n");
  if (m != n) quit("inverse: matrix is not square\n");

  /* --- form augmented REAL matrix U = (A,I) --- */
  ncol = 2*n;
  U = newmatrix(m,ncol);
  for (i = 0; i < m; i++)
    { for (j = 0; j < n; j++) U[i][j] = 1.0*A[i][j];
      for (j = n; j < ncol; j++) U[i][j] = 0.0;
      U[i][i+n] = 1.0;
    }

  /* --- tolerance relative to the size of A's entries --- */
  norm = 0.0;
  for (i = 0; i < m; i++)
    for (j = 0; j < n; j++) norm += abs(U[i][j]);
  if (norm == 0.0) norm = 1.0;
  tol = norm/1000000.0;

  /* --- forward elimination: reduce left half to unit upper tr. --- */
  for (i = 0; i < m; i++)
    { /* pivot: row e >= i with the largest |U[e][i]|; e is reset for
         every column so a stale row index is never reused */
      e = i;
      s = abs(U[i][i]);
      for (j = i+1; j < m; j++)
        if (abs(U[j][i]) > s) { s = abs(U[j][i]); e = j; }
      if (s < tol) quit("inverse: matrix is singular\n");
      if (e != i)
        { tmp = U[e]; U[e] = U[i]; U[i] = tmp; }

      /* scale the pivot row so that U[i][i] == 1 */
      q = 1.0/U[i][i];
      for (j = i; j < ncol; j++) U[i][j] = q*U[i][j];

      /* eliminate column i below the pivot.  Only the entry meant to
         become zero is forced to 0.0; flushing other small entries
         against tol (scaled to A) would wipe out the identity columns
         whenever A has large entries */
      for (k = i+1; k < m; k++)
        { p = U[k][i];
          for (l = i+1; l < ncol; l++) U[k][l] = U[k][l] - p*U[i][l];
          U[k][i] = 0.0;
        }
    }

  /* --- back substitution: turn (T, Y) into (I, inv(A)) --- */
  for (l = n; l < ncol; l++)
    for (i = m-1; i >= 0; i--)
      { s = U[i][l];
        for (j = i+1; j < n; j++) s = s - U[i][j]*U[j][l];
        U[i][l] = s;
      }

  /* --- copy the right half out as the result --- */
  V = newmatrix(m,n);
  for (i = 0; i < m; i++)
    for (j = 0; j < n; j++) V[i][j] = U[i][j+n];
  return V;
}

concat(A,B)
/* -------------------------------------------------------------------
   Returns the array concatenation A + B of the STRING arrays A and B.
   It works element by element on two STRING vectors of equal length,
   or on two STRING matrices of equal order.
   A matrix is told apart from a vector by its first element being a
   VECTOR.  sizeof() cannot be used for this, because sizeof of a
   STRING element is its length (an INTEGER).
   Quits on non-arrays, mixed shapes, size mismatch or non-STRING data.
   -------------------------------------------------------------------*/
{
  local i,j,ma,na,mb,nb,ta,tb,V;
  if (typeof(A) != VECTOR) quit("concat: 1st argument is not an array\n");
  if (typeof(B) != VECTOR) quit("concat: 2nd argument is not an array\n");
  ma = sizeof(A);
  mb = sizeof(B);
  ta = (typeof(A[0]) == VECTOR);   /* true when A is a matrix */
  tb = (typeof(B[0]) == VECTOR);   /* true when B is a matrix */
  if (ta != tb) quit("concat: arrays are not of same type\n");
  if (ta)
    { na = sizeof(A[0]);
      nb = sizeof(B[0]);
      if ((ma != mb) || (na != nb)) quit("concat: matrices are incompatible for concatination\n");
      if (typeof(A[0][0]) != STRING) quit("concat: 1st argument is not of type STRING\n");
      if (typeof(B[0][0]) != STRING) quit("concat: 2nd argument is not of type STRING\n");
      V = newmatrix(ma,na);
      for (i = 0; i < ma; i++)
        for (j = 0; j < na; j++) V[i][j] = A[i][j] + B[i][j];
    }
  else
    { if (ma != mb) quit("concat: vectors are incompatible for concatination\n");
      if (typeof(A[0]) != STRING) quit("concat: 1st argument is not of type STRING\n");
      if (typeof(B[0]) != STRING) quit("concat: 2nd argument is not of type STRING\n");
      V = newvector(ma);
      for (j = 0; j < ma; j++) V[j] = A[j] + B[j];
    }
  return V;
}


newmatrix(m,n)
/* -----------------------------------------------------------
   Allocates an m x n matrix: a VECTOR of m rows, each a fresh
   (unshared) VECTOR of n elements of type NIL.
   -----------------------------------------------------------*/
{
  local row,M;
  M = newvector(m);
  row = 0;
  while (row < m)
    { M[row] = newvector(n);
      row++;
    }
  return M;
}

abs(n)
/* ------------------------------------------
   Returns the absolute value of the number n.
   ------------------------------------------*/
{
  if (n < 0.0) return -n;
  return n;
}

 
int_10(n)
/* -----------------------------------------------------------------
   Returns the decimal string for the integer n, with a leading "-"
   when n is negative.  Digits come from n%10.  n is made
   non-negative first because, assuming C-like truncating % and /, a
   negative n would otherwise yield negative digit codes.
   -----------------------------------------------------------------*/
{
  local neg,s;
  if (n == 0) return "0";
  neg = (n < 0);
  if (neg) n = -n;
  s = "";
  while (n) { s = '0' + n%10 + s; n /= 10; }
  if (neg) s = "-" + s;
  return s;
}

